# SPP结构解决多尺度问题
# 对13*13*1024的特征图进行三次卷积后，经过不同大小的池化核池化

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from CSPDarknet53 import cspdarknet  # 导入网络模型
from CSPDarknet53 import conv_block  # 导入标准卷积块

def spp(inputs):

    # 获取网络的三个输出特征层
    feat1, feat2, feat3 = cspdarknet(inputs)

    # 对最后一个输出特征层进行3次卷积
    # [13,13,1024]==>[13,13,512]
    p5 = conv_block(feat3, filters=512, kernel_size=(1,1), strides=1)
    # [13,13,512]==>[13,13,1024]
    p5 = conv_block(p5, filters=1024, kernel_size=(3,3), strides=1)
    # [13,13,1024]==>[13,13,512]
    p5 = conv_block(p5, filters=512, kernel_size=(1,1), strides=1)

    # 经过不同尺度的最大池化后相堆叠
    maxpool1 = layers.MaxPooling2D(pool_size=(13,13), strides=1, padding='same')(p5)
    maxpool2 = layers.MaxPooling2D(pool_size=(9,9), strides=1, padding='same')(p5)
    maxpool3 = layers.MaxPooling2D(pool_size=(5,5), strides=1, padding='same')(p5)
    
    # 四种尺度在通道维度上堆叠[13,13,2048]
    p5 = layers.concatenate([maxpool1, maxpool2, maxpool3, p5])

    # 三次卷积调整通道数
    # [13,13,2048]==>[13,13,512]
    p5 = conv_block(p5, filters=512, kernel_size=(1,1), strides=1)
    # [13,13,512]==>[13,13,1024]
    p5 = conv_block(p5, filters=1024, kernel_size=(3,3), strides=1)
    # [13,13,1024]==>[13,13,512]
    p5 = conv_block(p5, filters=512, kernel_size=(1,1), strides=1)

    return feat1, feat2, p5


# if __name__ == '__main__':
    
#     inputs = keras.Input(shape=[416,416,3])
#     p5 = spp(inputs)
#     print(p5.shape)